// Copyright 2022 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_DNN_ACCELERATOR_HLS_SRC_LAYERNORMUNIT_H_
#define THIRD_PARTY_DNN_ACCELERATOR_HLS_SRC_LAYERNORMUNIT_H_

#include <ac_math.h>
#include <mc_connections.h>
#include <systemc.h>

#include "src/AccelTypes.h"
#include "src/VectorUtils.h"

// Streaming layer-normalization unit.
//
// For every VectorParams popped from paramsIn, the unit pops Y*X packed
// vectors from vectorIn three times in a row (mean pass, variance pass,
// normalization pass) and pushes Y*X result vectors to vectorOut during the
// last pass.
// NOTE(review): the producer presumably replays the same data for each of the
// three passes -- confirm against whatever drives vectorIn.
//
// Template parameters:
//   DTYPE - element type of the vectors. custom_sqrt() takes ac::bfloat16
//           references and run() hands it DTYPE lvalues, so DTYPE presumably
//           has to be ac::bfloat16 -- TODO confirm.
//   WIDTH - second parameter of the Pack1D vectors moved through the unit
//           (presumably the lane count -- confirm in AccelTypes.h).
template <typename DTYPE, int WIDTH>
SC_MODULE(LayerNormUnit) {
  // Clock (the thread is sensitive to its rising edge) and active-low
  // asynchronous reset.
  sc_in<bool> CCS_INIT_S1(clk);
  sc_in<bool> CCS_INIT_S1(rstn);

  // One VectorParams per normalization job, followed by the data stream in
  // and the normalized stream out.
  Connections::In<VectorParams> CCS_INIT_S1(paramsIn);
  Connections::In<Pack1D<DTYPE, WIDTH> > CCS_INIT_S1(vectorIn);
  Connections::Out<Pack1D<DTYPE, WIDTH> > CCS_INIT_S1(vectorOut);

  SC_CTOR(LayerNormUnit) {
    SC_THREAD(run);
    sensitive << clk.pos();
    async_reset_signal_is(rstn, false);
  }

  // Square root of `input`, written to `output`.
  //
  // The operand is first converted to an unsigned ac_fixed<10, 6> (truncating
  // quantization, saturating overflow), i.e. values are clamped to [0, 64)
  // with 4 fractional bits, then fed to the piecewise-linear ac_sqrt_pwl
  // approximation; the fixed-point result is converted back to bfloat16.
#pragma hls_design interface ccore
#pragma hls_pipeline_init_interval 1
  void custom_sqrt(ac::bfloat16 & input, ac::bfloat16 & output) {
    typedef ac_fixed<10, 6, false, AC_TRN, AC_SAT> SqrtFixed;

    SqrtFixed radicand =
        input.convert_to_ac_fixed<10, 6, false, AC_TRN, AC_SAT>();
    SqrtFixed root;
    ac_math::ac_sqrt_pwl(radicand, root);

    output = ac::bfloat16(root);
  }

  // Main clocked thread: resets the channels, then serves one normalization
  // job per VectorParams forever.
  void run() {
    paramsIn.Reset();
    vectorIn.Reset();
    vectorOut.Reset();

    wait();

    while (true) {
      VectorParams params = paramsIn.Pop();

      // Trip counts of the two nested loops, looked up in row 1 of
      // params.loops through the y/x loop indices.
      const int yCount = params.loops[1][params.yLoopIndex[1]];
      const int xCount = params.loops[1][params.xLoopIndex[1]];

      // Divisor shared by the mean and the variance: the number of vectors
      // popped per pass.
      // NOTE(review): this is Y*X, not Y*X*WIDTH -- confirm it matches what
      // reduceSum accumulates.
      DTYPE numVectors = static_cast<DTYPE>(yCount * xCount);

      // Pass 1: accumulate the incoming vectors to obtain the mean.
      DTYPE total = 0;
#pragma hls_pipeline_init_interval 1
      for (int yi = 0; yi < yCount; yi++) {
        for (int xi = 0; xi < xCount; xi++) {
          Pack1D<DTYPE, WIDTH> chunk = vectorIn.Pop();
          reduceSum<DTYPE, WIDTH>(chunk, total);
        }
      }
      DTYPE mean = total / numVectors;

      // Pass 2: accumulate the squared deviations from the mean to obtain the
      // variance, then take its square root.
      DTYPE squaredTotal = 0;
#pragma hls_pipeline_init_interval 1
      for (int yi = 0; yi < yCount; yi++) {
        for (int xi = 0; xi < xCount; xi++) {
          Pack1D<DTYPE, WIDTH> chunk = vectorIn.Pop();
          vectorSub<DTYPE, WIDTH>(chunk, mean);
          vectorSquare<DTYPE, WIDTH>(chunk);
          reduceSum<DTYPE, WIDTH>(chunk, squaredTotal);
        }
      }
      DTYPE var = squaredTotal / numVectors;
      DTYPE std_dev;
      custom_sqrt(var, std_dev);

      // Pass 3: emit (vector - mean) / std_dev for every incoming vector.
      // NOTE(review): no epsilon is added, so std_dev is handed to vectorDiv
      // as-is even when the variance is zero.
#pragma hls_pipeline_init_interval 1
      for (int yi = 0; yi < yCount; yi++) {
        for (int xi = 0; xi < xCount; xi++) {
          Pack1D<DTYPE, WIDTH> chunk = vectorIn.Pop();
          vectorSub<DTYPE, WIDTH>(chunk, mean);
          vectorDiv<DTYPE, WIDTH>(chunk, std_dev);
          vectorOut.Push(chunk);
        }
      }
    }
  }
};

#endif  // THIRD_PARTY_DNN_ACCELERATOR_HLS_SRC_LAYERNORMUNIT_H_
